# util/_collections_cy.py
# Copyright (C) 2010-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: disable-error-code="misc, no-any-return, no-untyped-def, override"
# mypy: disable-error-code="untyped-decorator"

from __future__ import annotations

from typing import AbstractSet
from typing import Any
from typing import Dict
from typing import Hashable
from typing import Iterable
from typing import Iterator
from typing import List
from typing import NoReturn
from typing import Optional
from typing import Set
from typing import Tuple
from typing import TypeVar
from typing import Union

from .typing import Self

# START GENERATED CYTHON IMPORT
# This section is automatically generated by the script tools/cython_imports.py
try:
    # NOTE: the cython compiler needs this "import cython" in the file, it
    # can't be only "from sqlalchemy.util import cython" with the fallback
    # in that module
    import cython
except ModuleNotFoundError:
    from sqlalchemy.util import cython


def _is_compiled() -> bool:
    """Utility function to indicate if this module is compiled or not.

    Returns the value of ``cython.compiled`` from whichever ``cython``
    module was imported above: the real Cython shim, or the
    ``sqlalchemy.util.cython`` fallback when Cython is not installed.
    """
    return cython.compiled  # type: ignore[no-any-return,unused-ignore]


# END GENERATED CYTHON IMPORT
# Generic type parameters for the collection annotations in this module:
# _T is the element type of a collection, _S the element type of the
# "other" operand in mixed-type set operations (e.g. union -> _T | _S).
_T = TypeVar("_T")
_S = TypeVar("_S")


@cython.ccall
def unique_list(seq: Iterable[_T]) -> List[_T]:
    """Return the elements of *seq* as a list with repeats dropped.

    The first occurrence of each element wins and relative order is kept.
    Elements must be hashable.
    """
    # Design notes:
    # - ``w = {x: None for x in seq}; return PyDict_Keys(w)`` seemed somewhat
    #   faster for small inputs but significantly slower for large ones.
    # - If a ``hashfunc`` ever needs to be supported, two variants were
    #   tested:
    #   * ``w: dict = {hashfunc(x): x for x in seq}`` followed by
    #     ``PyDict_Values(w)`` (``list(w.values())`` uncompiled) is faster,
    #     but keeps the *last* element matching each hash;
    #   * an explicit loop that records ``hashfunc(x)`` in a ``seen`` set and
    #     appends ``x`` only when its hash is new is slower, but keeps the
    #     *first* element matching each hash.
    if not cython.compiled:
        # dict keys are unique and remember insertion order
        return list(dict.fromkeys(seq))

    # ``set.add`` returns None, so ``not set.add(...)`` is always true; it
    # is there only to record ``x`` as seen inside the comprehension.
    seen: Set[_T] = set()
    return [x for x in seq if x not in seen and not set.add(seen, x)]


@cython.cclass
class OrderedSet(Set[_T]):
    """A set implementation that maintains insertion order.

    Membership lives in the underlying ``set`` storage, while ``_list``
    records the order in which elements were added.  Every mutating method
    keeps the two in sync, so ``_list`` never contains duplicates.
    Iteration, indexing, ``repr()`` and ``pop()`` all follow ``_list``.

    """

    __slots__ = ("_list",)
    _list: List[_T]

    @classmethod
    def __class_getitem__(cls, key: Any) -> type[Self]:
        # ``OrderedSet[X]`` evaluates to ``OrderedSet`` itself at runtime;
        # the subscript only matters to static type checkers.
        return cls

    def __init__(self, d: Optional[Iterable[_T]] = None) -> None:
        # Build from an optional iterable; repeats keep their first position.
        if d is not None:
            if isinstance(d, set) or isinstance(d, dict):
                # elements are already unique; keep their iteration order
                self._list = list(d)
            else:
                self._list = unique_list(d)
            set.__init__(self, self._list)
        else:
            self._list = []
            set.__init__(self)

    def copy(self) -> OrderedSet[_T]:
        """Return a shallow copy with the same element order."""
        return self._from_list(list(self._list))

    @cython.final
    @cython.cfunc
    @cython.inline
    def _from_list(self, new_list: List[_T]) -> OrderedSet:  # type: ignore[type-arg] # noqa: E501
        # Wrap *new_list* (taken over, not copied) in a new OrderedSet without
        # re-checking uniqueness: callers must pass a duplicate-free list.
        # The result is always a plain OrderedSet, even for subclasses.
        new: OrderedSet = OrderedSet.__new__(OrderedSet)  # type: ignore[type-arg] # noqa: E501
        new._list = new_list
        set.update(new, new_list)
        return new

    def add(self, element: _T, /) -> None:
        """Append *element* to the end unless it is already present."""
        if element not in self:
            self._list.append(element)
            set.add(self, element)

    def remove(self, element: _T, /) -> None:
        """Remove *element*; raise KeyError if it is not a member."""
        # set.remove will raise if element is not in self
        set.remove(self, element)
        self._list.remove(element)

    def pop(self) -> _T:
        """Remove and return the last element; raise KeyError if empty."""
        try:
            value = self._list.pop()
        except IndexError:
            raise KeyError("pop from an empty set") from None
        set.remove(self, value)
        return value

    def insert(self, pos: cython.Py_ssize_t, element: _T, /) -> None:
        """Insert *element* before index *pos* unless it is already present.

        *pos* follows ``list.insert`` semantics (negative and out-of-range
        indexes are accepted).
        """
        if element not in self:
            self._list.insert(pos, element)
            set.add(self, element)

    def discard(self, element: _T, /) -> None:
        """Remove *element* if it is a member; otherwise do nothing."""
        if element in self:
            set.remove(self, element)
            self._list.remove(element)

    def clear(self) -> None:
        """Remove all elements."""
        set.clear(self)
        self._list = []

    def __getitem__(self, key: cython.Py_ssize_t) -> _T:
        # positional access into the insertion order
        return self._list[key]

    def __iter__(self) -> Iterator[_T]:
        # iterate in insertion order, not the set's hash order
        return iter(self._list)

    def __add__(self, other: Iterator[_T]) -> OrderedSet[_T]:
        # ``a + b`` is an alias for ``a.union(b)``
        return self.union(other)

    def __repr__(self) -> str:
        return "%s(%r)" % (self.__class__.__name__, self._list)

    __str__ = __repr__

    # @cython.ccall # cdef function cannot have star argument
    def update(self, *iterables: Iterable[_T]) -> None:
        """Append every element of *iterables* not already present, in order."""
        for iterable in iterables:
            for element in iterable:
                # inline of add. mainly for python, since for cython we
                # could create an @cfunc @inline _add function that would
                # perform the same
                if element not in self:
                    self._list.append(element)
                    set.add(self, element)

    def __ior__(
        self: OrderedSet[Union[_T, _S]], iterable: AbstractSet[_S]
    ) -> OrderedSet[Union[_T, _S]]:
        self.update(iterable)
        return self

    # @cython.ccall # cdef function cannot have star argument
    def union(self, *other: Iterable[_S]) -> OrderedSet[Union[_T, _S]]:
        """Return a new set: self's elements, then new ones from *other*."""
        result: OrderedSet[Union[_T, _S]] = self._from_list(list(self._list))
        result.update(*other)
        return result

    def __or__(self, other: AbstractSet[_S]) -> OrderedSet[Union[_T, _S]]:
        return self.union(other)

    # @cython.ccall # cdef function cannot have star argument
    def intersection(self, *other: Iterable[Hashable]) -> OrderedSet[_T]:
        """Return a new set of self's elements also in every *other*.

        The result keeps self's order.
        """
        other_set: Set[Any] = set.intersection(self, *other)
        return self._from_list([a for a in self._list if a in other_set])

    def __and__(self, other: AbstractSet[Hashable]) -> OrderedSet[_T]:
        return self.intersection(other)

    @cython.ccall
    @cython.annotation_typing(False)  # avoid cython crash from generic return
    def symmetric_difference(
        self, other: Iterable[_S], /
    ) -> OrderedSet[Union[_T, _S]]:
        """Return a new set of elements in exactly one of self and *other*.

        Self's surviving elements come first in their order, followed by
        the elements only in *other*.
        """
        collection: Iterable[Any]
        other_set: Set[_S]
        if isinstance(other, set):
            other_set = cython.cast(set, other)
            collection = other_set
        elif hasattr(other, "__len__"):
            collection = other
            other_set = set(other)
        else:
            # one-shot iterator: materialize so it can be read twice
            collection = list(other)
            other_set = set(collection)
        result: OrderedSet[Union[_T, _S]] = self._from_list(
            [a for a in self._list if a not in other_set]
        )
        result.update([a for a in collection if a not in self])
        return result

    def __xor__(self, other: AbstractSet[_S]) -> OrderedSet[Union[_T, _S]]:
        return self.symmetric_difference(other)

    # @cython.ccall # cdef function cannot have star argument
    def difference(self, *other: Iterable[Hashable]) -> OrderedSet[_T]:
        """Return a new set of self's elements not in any *other*.

        The result keeps self's order.
        """
        other_set: Set[Any] = set.difference(self, *other)
        return self._from_list([a for a in self._list if a in other_set])

    def __sub__(self, other: AbstractSet[Hashable]) -> OrderedSet[_T]:
        return self.difference(other)

    # @cython.ccall # cdef function cannot have star argument
    def intersection_update(self, *other: Iterable[Hashable]) -> None:
        """Keep only the elements also found in every *other*, in place."""
        set.intersection_update(self, *other)
        self._list = [a for a in self._list if a in self]

    def __iand__(self, other: AbstractSet[Hashable]) -> OrderedSet[_T]:
        self.intersection_update(other)
        return self

    @cython.ccall
    @cython.annotation_typing(False)  # avoid cython crash from generic return
    def symmetric_difference_update(self, other: Iterable[_T], /) -> None:
        """Keep, in place, the elements found in exactly one of self and *other*.

        Surviving elements keep their order.  Elements only in *other* are
        appended once each, in the order they are first seen in *other*.
        """
        collection = other if hasattr(other, "__len__") else list(other)
        # The elements that will be added are exactly those of ``other`` not
        # currently present.  Compute them before the set changes, and
        # de-duplicate them: ``set.symmetric_difference_update`` collapses
        # repeats in a non-set *other*, so ``_list`` must do the same or it
        # would diverge from the set contents.
        added = [a for a in unique_list(collection) if a not in self]
        set.symmetric_difference_update(self, collection)
        self._list = [a for a in self._list if a in self]
        self._list += added

    def __ixor__(
        self: OrderedSet[Union[_T, _S]], other: AbstractSet[_S]
    ) -> OrderedSet[Union[_T, _S]]:
        self.symmetric_difference_update(other)
        return self

    # @cython.ccall # cdef function cannot have star argument
    def difference_update(self, *other: Iterable[Hashable]) -> None:
        """Remove, in place, every element found in any *other*."""
        set.difference_update(self, *other)
        self._list = [a for a in self._list if a in self]

    def __isub__(self, other: AbstractSet[Hashable]) -> OrderedSet[_T]:
        self.difference_update(other)
        return self


# Identity key used by IdentitySet: maps an object to an integer that is
# unique among objects alive at the same time.  IdentitySet stores the object
# itself as the dict value, which keeps it alive, so the key stays
# unambiguous while the object is a member.
if cython.compiled:

    @cython.cfunc
    @cython.inline
    def _get_id(item: object, /) -> cython.ulonglong:
        # Compiled: reinterpret the object pointer (its address) as an
        # unsigned long long.  This presumably avoids the overhead of a
        # Python-level id() call; in CPython, id() also returns the address.
        return cython.cast(
            cython.ulonglong,
            cython.cast(cython.pointer(cython.void), item),
        )

else:
    # Pure-Python mode: the builtin id() provides the identity key.
    _get_id = id


@cython.cclass
class IdentitySet:
    """A set that considers only object id() for uniqueness.

    This strategy has edge cases for builtin types- it's possible to have
    two 'foo' strings in one of these sets, for example.  Use sparingly.

    Members are kept in ``_members``, a dict mapping the identity key from
    ``_get_id()`` to the object itself.  Iteration and ``repr()`` follow the
    insertion order of that dict.

    """

    __slots__ = ("_members",)
    _members: Dict[int, Any]

    def __init__(self, iterable: Optional[Iterable[Any]] = None):
        # the code assumes this class is ordered
        self._members = {}
        if iterable:
            self.update(iterable)

    def add(self, value: Any, /) -> None:
        """Add *value*; re-adding the same object keeps its position."""
        self._members[_get_id(value)] = value

    def __contains__(self, value) -> bool:
        # identity-based membership; object equality is never consulted
        return _get_id(value) in self._members

    @cython.ccall
    def remove(self, value: Any, /):
        """Remove *value*; raise KeyError if this object is not a member."""
        del self._members[_get_id(value)]

    def discard(self, value, /) -> None:
        """Remove *value* if it is a member; otherwise do nothing."""
        try:
            self.remove(value)
        except KeyError:
            pass

    def pop(self) -> Any:
        """Remove and return the most recently added member.

        Raises KeyError if the set is empty.
        """
        pair: Tuple[Any, Any]
        try:
            # dict.popitem() removes the last inserted (key, value) pair
            pair = self._members.popitem()
        except KeyError:
            # replace dict's message; the original context adds nothing
            raise KeyError("pop from an empty set") from None
        return pair[1]

    def clear(self) -> None:
        """Remove all members."""
        self._members.clear()

    def __eq__(self, other: Any) -> bool:
        # equal only to another IdentitySet holding the very same objects
        other_: IdentitySet
        if isinstance(other, IdentitySet):
            other_ = other
            return self._members == other_._members
        else:
            return False

    def __ne__(self, other: Any) -> bool:
        other_: IdentitySet
        if isinstance(other, IdentitySet):
            other_ = other
            return self._members != other_._members
        else:
            return True

    @cython.ccall
    def issubset(self, iterable: Iterable[Any], /) -> cython.bint:
        """Return True if every member is (by identity) also in *iterable*."""
        other: IdentitySet
        if isinstance(iterable, IdentitySet):
            other = iterable
        else:
            other = self.__class__(iterable)

        return self._members.keys() <= other._members.keys()

    def __le__(self, other: Any) -> bool:
        if not isinstance(other, IdentitySet):
            return NotImplemented
        return self.issubset(other)

    def __lt__(self, other: Any) -> bool:
        # proper subset: strictly smaller and contained
        if not isinstance(other, IdentitySet):
            return NotImplemented
        return len(self) < len(other) and self.issubset(other)

    @cython.ccall
    def issuperset(self, iterable: Iterable[Any], /) -> cython.bint:
        """Return True if every object in *iterable* is (by identity) a member."""
        other: IdentitySet
        if isinstance(iterable, IdentitySet):
            other = iterable
        else:
            other = self.__class__(iterable)

        return self._members.keys() >= other._members.keys()

    def __ge__(self, other: Any) -> bool:
        if not isinstance(other, IdentitySet):
            return NotImplemented
        return self.issuperset(other)

    def __gt__(self, other: Any) -> bool:
        # proper superset: strictly larger and containing
        if not isinstance(other, IdentitySet):
            return NotImplemented
        return len(self) > len(other) and self.issuperset(other)

    @cython.ccall
    def union(self, iterable: Iterable[Any], /) -> IdentitySet:
        """Return a new set with the members of self and *iterable*."""
        result: IdentitySet = self.__class__()
        result._members.update(self._members)
        result.update(iterable)
        return result

    def __or__(self, other: Any) -> IdentitySet:
        if not isinstance(other, IdentitySet):
            return NotImplemented
        return self.union(other)

    @cython.ccall
    def update(self, iterable: Iterable[Any], /):
        """Add every object of *iterable* to this set, in place."""
        members: Dict[int, Any] = self._members
        if isinstance(iterable, IdentitySet):
            # same key scheme: merge the dicts directly
            members.update(cython.cast(IdentitySet, iterable)._members)
        else:
            for obj in iterable:
                members[_get_id(obj)] = obj

    def __ior__(self, other: Any) -> IdentitySet:
        if not isinstance(other, IdentitySet):
            return NotImplemented
        self.update(other)
        return self

    @cython.ccall
    def difference(self, iterable: Iterable[Any], /) -> IdentitySet:
        """Return a new set of the members not in *iterable* (by identity)."""
        result: IdentitySet = self.__new__(self.__class__)
        if isinstance(iterable, IdentitySet):
            other = cython.cast(IdentitySet, iterable)._members.keys()
        else:
            other = {_get_id(obj) for obj in iterable}

        result._members = {
            k: v for k, v in self._members.items() if k not in other
        }
        return result

    def __sub__(self, other: IdentitySet) -> IdentitySet:
        if not isinstance(other, IdentitySet):
            return NotImplemented
        return self.difference(other)

    # def difference_update(self, iterable: Iterable[Any]) -> None:
    @cython.ccall
    def difference_update(self, iterable: Iterable[Any], /):
        """Remove, in place, every member also found in *iterable*."""
        other: IdentitySet = self.difference(iterable)
        self._members = other._members

    def __isub__(self, other: IdentitySet) -> IdentitySet:
        if not isinstance(other, IdentitySet):
            return NotImplemented
        self.difference_update(other)
        return self

    @cython.ccall
    def intersection(self, iterable: Iterable[Any], /) -> IdentitySet:
        """Return a new set of the members also in *iterable* (by identity)."""
        result: IdentitySet = self.__new__(self.__class__)
        if isinstance(iterable, IdentitySet):
            other = cython.cast(IdentitySet, iterable)._members
        else:
            other = {_get_id(obj) for obj in iterable}
        result._members = {
            k: v for k, v in self._members.items() if k in other
        }
        return result

    def __and__(self, other):
        if not isinstance(other, IdentitySet):
            return NotImplemented
        return self.intersection(other)

    # def intersection_update(self, iterable: Iterable[Any]) -> None:
    @cython.ccall
    def intersection_update(self, iterable: Iterable[Any], /):
        """Keep, in place, only the members also found in *iterable*."""
        other: IdentitySet = self.intersection(iterable)
        self._members = other._members

    def __iand__(self, other: IdentitySet) -> IdentitySet:
        if not isinstance(other, IdentitySet):
            return NotImplemented
        self.intersection_update(other)
        return self

    @cython.ccall
    def symmetric_difference(self, iterable: Iterable[Any], /) -> IdentitySet:
        """Return a new set of objects in exactly one of self and *iterable*."""
        result: IdentitySet = self.__new__(self.__class__)
        other: Dict[int, Any]
        if isinstance(iterable, IdentitySet):
            other = cython.cast(IdentitySet, iterable)._members
        else:
            other = {_get_id(obj): obj for obj in iterable}
        result._members = {
            k: v for k, v in self._members.items() if k not in other
        }
        result._members.update(
            [(k, v) for k, v in other.items() if k not in self._members]
        )
        return result

    def __xor__(self, other: IdentitySet) -> IdentitySet:
        if not isinstance(other, IdentitySet):
            return NotImplemented
        return self.symmetric_difference(other)

    # def symmetric_difference_update(self, iterable: Iterable[Any]) -> None:
    @cython.ccall
    def symmetric_difference_update(self, iterable: Iterable[Any], /):
        """Keep, in place, the objects in exactly one of self and *iterable*."""
        other: IdentitySet = self.symmetric_difference(iterable)
        self._members = other._members

    def __ixor__(self, other: IdentitySet) -> IdentitySet:
        if not isinstance(other, IdentitySet):
            return NotImplemented
        # ``a ^= b`` must mutate ``a``; the non-updating
        # symmetric_difference() would build a new set and discard it.
        self.symmetric_difference_update(other)
        return self

    @cython.ccall
    def copy(self) -> IdentitySet:
        """Return a shallow copy (same member objects, new container)."""
        cp: IdentitySet = self.__new__(self.__class__)
        cp._members = self._members.copy()
        return cp

    def __copy__(self) -> IdentitySet:
        return self.copy()

    def __len__(self) -> int:
        return len(self._members)

    def __iter__(self) -> Iterator[Any]:
        # yields the member objects, not their identity keys
        return iter(self._members.values())

    def __hash__(self) -> NoReturn:
        # mutable container: unhashable, like the builtin set
        raise TypeError("set objects are unhashable")

    def __repr__(self) -> str:
        return "%s(%r)" % (
            self.__class__.__name__,
            list(self._members.values()),
        )
